//+------------------------------------------------------------------+
//|                                     dimension reduction test.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "https://www.mql5.com/en/users/omegajoctan"
#property version   "1.00"
#property script_show_inputs

#include <TruncatedSVD.mqh>
#include <NMF.mqh>
#include <matrix_utils.mqh>
#include <Linear Regression.mqh>
#include <metrics.mqh>

//--- Shared helper used for matrix concatenation / splitting in TrainTestLR
CMatrixutils matrix_utils;
//--- Allocated with new inside TrainTestLR (depending on dimension_redux) and deleted at its end
CTruncatedSVD *truncated_svd;
CNMF *nmf;

//--- Regression model trained in TrainTestLR and the metrics object used to score it
CLinearRegression lr;
CMetrics metrics;

input group "Dimensionality Reduction Inputs"
//--- Passed to the CNMF constructor when dimension_redux == NMF
input int nmf_iterations = 30;
//--- Dimensionality-reduction method applied to the indicator feature matrix (None = no reduction)
enum dim_red {NMF, TRUNC_SVD, None};
input dim_red dimension_redux = TRUNC_SVD;

//--- Number of indicator handles; must match the slots 0..37 filled in the Indicators constructor
#define INDICATORS 38

//--- Number of bars copied for both the feature matrix and the close-price target
input int buffer_size = 100;

//+------------------------------------------------------------------+
//| Data copied from one indicator: a matrix of buffer values        |
//| (one column per buffer) plus the comma-separated buffer labels.  |
//+------------------------------------------------------------------+
struct buffer_info
  {
   matrix values;   // buffer values, one column per buffer
   string names;    // comma-separated buffer labels

   //--- Splits names_str on ',' into NamesArr and returns the
   //--- StringSplit result (the number of labels found).
   int GetNames(string names_str, string &NamesArr[])
     {
      const ushort comma = ',';
      int parts = StringSplit(names_str, comma, NamesArr);
      return parts;
     }
  };

struct buffers_struct:public buffer_info
  {
   //+------------------------------------------------------------------+
   //| Copy all buffers of an indicator into a (count x N) matrix,      |
   //| one column per buffer, where N is the number of comma-separated  |
   //| labels in buff_names.                                            |
   //| handle     - indicator handle                                    |
   //| buff_names - comma-separated buffer labels (one per buffer)      |
   //| start      - first bar position passed to CopyIndicatorBuffer    |
   //| count      - number of values to copy per buffer                 |
   //| Columns whose copy fails (or returns a short vector) are left    |
   //| as zeros and an error is printed.                                |
   //+------------------------------------------------------------------+
   buffer_info _CopyRates(long handle, string buff_names, int start, int count) //Copy all buffers given available buffers from an indicator
     {
      string names_arr[];

      //--- Buffer names == total buffers | this function extracts the number of buffers
      int total_buffers = GetNames(buff_names, names_arr);

      buffer_info buffers;

      if(total_buffers <= 0)
        {
         printf("No buffer names found in \"%s\"", buff_names);
         return buffers;
        }

      //--- Resize alone does not guarantee defined values, so zero the matrix:
      //--- a buffer that fails to copy then contributes a column of zeros
      buffers.values.Resize(count, total_buffers);
      buffers.values.Fill(0.0);

      vector v(count);
      v.Fill(0.0);

      for(int i=0; i<total_buffers; i++)
        {
         if(!v.CopyIndicatorBuffer(handle, i, start, count))
           {
            printf("Failed to copy %s Buffer %d Err = %d ",names_arr[i],i,GetLastError());
            continue;
           }

         //--- Fewer bars than requested would make Col() fail silently
         if(v.Size() != (ulong)count)
           {
            printf("Copied %d values for %s Buffer %d, expected %d",(int)v.Size(),names_arr[i],i,count);
            continue;
           }

         buffers.values.Col(v, i); //Store each buffer into a values matrix
        }
      return buffers;
     };
  };
  
//+------------------------------------------------------------------+
//| Creates INDICATORS built-in indicator handles for one symbol /   |
//| timeframe and copies all of their buffers into a feature matrix. |
//| handle[i], name[i] and buffers[i].names describe indicator i.    |
//+------------------------------------------------------------------+
struct Indicators
  {
   long handle[];
   string name[];

   buffers_struct buffers[];

   //--- Create every handle and its buffer labels.
   Indicators(string symbol,ENUM_TIMEFRAMES timeframe)
     {
      ArrayResize(handle, INDICATORS);
      ArrayResize(name, INDICATORS);
      ArrayResize(buffers, INDICATORS);


      TesterHideIndicators(true);

      // Trend following(13)
      name[0] = "Adaptive Moving Averate";
      name[1] = "Average Directional Movement Index";
      name[2] = "Average Directional Movement Index Wilder";
      name[3] = "Bollinger Bands";
      name[4] = "Double Exponential Moving Average";
      name[5] = "Envelopes";
      name[6] = "Fractal Adaptive Moving Average";
      name[7] = "Ichimoku Kinko Hyo";
      name[8] = "Moving Average";
      name[9] = "Parabolic SAR";
      name[10] = "Standard Deviation";
      name[11] = "Tripple Exponential Moving Average";
      name[12] = "Variable Index Dynamic Average";

      // Oscillators (15)
      name[13] = "Average True Range";
      name[14] = "Bears Power";
      name[15] = "Bulls Power";
      name[16] = "Chainkin Oscillator";
      name[17] = "Commodity Channel Index";
      name[18] = "De Marker";
      name[19] = "Force Index";
      name[20] = "MACD";
      name[21] = "Momentum";
      name[22] = "Moving Average of Oscillator";
      name[23] = "Relative Strength Index";
      name[24] = "Relative Vigor Index";
      name[25] = "Stochastic Oscillator";
      name[26] = "Tripple Exponential Average";
      name[27] = "Williams' Percent Range";

      // Volumes(4)
      name[28] = "Accumulator Distributor";
      name[29] = "Money Flow Index";
      name[30] = "On Balance Volume";
      name[31] = "Volumes";

      // Bill Williams(6)
      name[32] = "Accelerator Oscillator";
      name[33] = "Alligator";
      name[34] = "Awesome Oscillator";
      name[35] = "Fractals";
      name[36] = "Gator Oscillator";
      name[37] = "Market Facilitation Index";

//--- Declaring and assigning the handles

//--- Trend

      handle[0] = iAMA(symbol, timeframe, 9 , 2 , 30, 0, PRICE_OPEN);
      buffers[0].names = " AMA";

      handle[1] = iADX(symbol, timeframe, 14);
      buffers[1].names = " ADX-MAIN_LINE, ADX-PLUSDI_LINE, ADX-MINUSDI_LINE";

      handle[2] = iADXWilder(symbol, timeframe, 14);
      buffers[2].names = " ADXWilder-MAIN_LINE, ADXWilder-PLUSDI_LINE, ADXWilder-MINUSDI_LINE";

      handle[3] = iBands(symbol, timeframe, 20, 0, 2.0, PRICE_OPEN);
      buffers[3].names = " BB-BASE_LINE, BB-UPPER_BAND, BB-LOWER_BAND";

      handle[4] = iDEMA(symbol, timeframe, 14, 0, PRICE_OPEN);
      buffers[4].names = " DEMA";

      handle[5] = iEnvelopes(symbol, timeframe, 14, 0, MODE_SMA, PRICE_OPEN, 0.1);
      buffers[5].names = " Envelopes-UPPER_LINE, Envelopes-LOWER_LINE";

      handle[6] = iFrAMA(symbol, timeframe, 14, 0, PRICE_OPEN);
      buffers[6].names = " FRAMA";

      handle[7] = iIchimoku(symbol, timeframe, 9, 26, 52);
      buffers[7].names = " Ichimoku-TENKANSEN_LINE, Ichimoku-KIJUNSEN_LINE, Ichimoku-SENKOUSPANA_LINE, Ichimoku-SENKOUSPANB_LINE, Ichimoku-CHIKOUSPAN_LINE";

      handle[8] = iMA(symbol, timeframe, 10, 0, MODE_SMA, PRICE_OPEN);
      buffers[8].names = " MA";

      handle[9] = iSAR(symbol, timeframe, 0.02, 0.2);
      buffers[9].names = " SAR";

      // NOTE(review): a period of 10000 looks unusually large next to the other
      // periods here and needs that much history before values are defined — confirm intent.
      handle[10] = iStdDev(symbol, timeframe, 10000, 0, MODE_SMA, PRICE_OPEN);
      buffers[10].names = " StdDev";

      // NOTE(review): TEMA uses PRICE_CLOSE while the other applied-price indicators
      // here use PRICE_OPEN; the regression target is the close of the same bars,
      // so this may leak the target into the features — confirm intent.
      handle[11] = iTEMA(symbol, timeframe, 14, 0, PRICE_CLOSE);
      buffers[11].names = " TEMA";


      handle[12] = iVIDyA(symbol, timeframe, 9, 12, 0, PRICE_OPEN);
      buffers[12].names = " ViDyA";

//--- Oscillators

      handle[13] = iATR(symbol, timeframe, 14);
      buffers[13].names = " ATR";

      handle[14] = iBearsPower(symbol, timeframe, 13);
      buffers[14].names = " BearsPower";

      handle[15] = iBullsPower(symbol, timeframe, 13);
      buffers[15].names = " BullsPower";

      handle[16] = iChaikin(symbol, timeframe, 3, 10, MODE_EMA, VOLUME_TICK);
      buffers[16].names = " Chainkin";

      handle[17] = iCCI(symbol, timeframe, 14, PRICE_OPEN);
      buffers[17].names = " CCI";

      handle[18] = iDeMarker(symbol, timeframe, 14);
      buffers[18].names = " Demarker";

      handle[19] = iForce(symbol, timeframe, 13, MODE_SMA, VOLUME_TICK);
      buffers[19].names = " Force";

      handle[20] = iMACD(symbol, timeframe, 12, 26, 9, PRICE_OPEN);
      buffers[20].names = " MACD-MAIN_LINE, MACD-SIGNAL_LINE";

      handle[21] = iMomentum(symbol, timeframe, 14, PRICE_OPEN);
      buffers[21].names = " Momentum";

      handle[22] = iOsMA(symbol, timeframe, 12, 26, 9, PRICE_OPEN);
      buffers[22].names = " OsMA";

      handle[23] = iRSI(symbol, timeframe, 14, PRICE_OPEN);
      buffers[23].names = " RSI";

      handle[24] = iRVI(symbol, timeframe, 10);
      buffers[24].names = " RVI-MAIN_LINE, RVI-SIGNAL_LINE";

      handle[25] = iStochastic(symbol, timeframe, 5, 3,3,MODE_SMA,STO_LOWHIGH);
      buffers[25].names = " Stochastic-MAIN_LINE, Stochastic-SIGNAL_LINE";

      // TriX buffer label (was mislabelled " TEMA", duplicating buffer 11's label)
      handle[26] = iTriX(symbol, timeframe, 14, PRICE_OPEN);
      buffers[26].names = " TRIX";

      handle[27] = iWPR(symbol, timeframe, 14);
      buffers[27].names = " WPR";


//--- Volumes

      handle[28] = iAD(symbol, timeframe, VOLUME_TICK);
      buffers[28].names = " AD";

      handle[29] = iMFI(symbol, timeframe, 14, VOLUME_TICK);
      buffers[29].names = " MFI";

      handle[30] = iOBV(symbol, timeframe, VOLUME_TICK);
      buffers[30].names = " OBV";

      handle[31] = iVolumes(symbol, timeframe, VOLUME_TICK);
      buffers[31].names = " Tick-Volumes";


//--- Bill williams;

      handle[32] = iAC(symbol, timeframe);
      buffers[32].names = " AC";

      handle[33] = iAlligator(symbol, timeframe, 13, 8,8,5,5,3, MODE_SMMA, PRICE_OPEN);
      buffers[33].names = " Alligator-GATORJAW_LINE, Alligator-GATORTEETH_LINE, Alligator-GATORLIPS_LINE";

      handle[34] = iAO(symbol, timeframe);
      buffers[34].names = " AO";

      // NOTE(review): Fractals buffers presumably hold EMPTY_VALUE on bars without
      // a fractal, which would dominate the feature matrix — confirm.
      handle[35] = iFractals(symbol, timeframe);
      buffers[35].names = " Fractals-UPPER_LINE, Fractals-LOWER_LINE";

      handle[36] = iGator(symbol, timeframe,13,8,8,5,5,3, MODE_SMMA, PRICE_OPEN);
      buffers[36].names = " Gator-UPPER_HISTOGRAM, Gator-LOWER_HISTOGRAM";

      handle[37] = iBWMFI(symbol, timeframe, VOLUME_TICK);
      buffers[37].names = " BWMFI";

      //--- Report handles that could not be created; their buffers will fail to copy later
      for(int i=0; i<INDICATORS; i++)
         if(handle[i] == INVALID_HANDLE)
            printf("Failed to create %s handle, last Err = %d", name[i], GetLastError());

      Print("Indicators total =",ArraySize(handle));
     }

   //--- Release every successfully created indicator handle.
   ~Indicators(void)
     {
      for(int i=0; i<ArraySize(handle); i++)
         if(handle[i] != INVALID_HANDLE)
            IndicatorRelease((int)handle[i]);
     }

   //+------------------------------------------------------------------+
   //| Copy `count` values (from bar `start`) of every buffer of every  |
   //| indicator and concatenate them column-wise into one matrix.      |
   //| buffer_names - the comma-separated labels of all buffers are     |
   //|                APPENDED to this string (it is not cleared).      |
   //+------------------------------------------------------------------+
   matrix GetAllBuffers(string &buffer_names, int start, int count)
     {
      matrix data = {};
      buffer_info indicator_info;

      int total = ArraySize(buffers);

      for(int i=0; i<total; i++)
        {
         indicator_info = buffers[i]._CopyRates(handle[i],buffers[i].names, start, count);
         data = CMatrixutils::concatenate(data, indicator_info.values, 1);

         buffer_names+= (buffers[i].names + (i==total-1 ? "":","));
        }
      return data;
     }
  };

//--- Timeframe used to create the indicator handles.
//--- NOTE: in MQL5 the trailing comment on an input line is the label shown in the inputs dialog.
input ENUM_TIMEFRAMES tf = PERIOD_CURRENT; //timeframe
//--- MqlRates buffer; OnStart switches it to series (newest-first) indexing
MqlRates rates[];
   
//--- Global indicator set, constructed at load time for the chart symbol and tf
Indicators indicators(Symbol(), tf);
//--- Receives the comma-separated buffer labels when TrainTestLR calls indicators.GetAllBuffers
string names;

//--- Not referenced by the code visible in this file
matrix input_matrix;
vector input_vector;
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
void OnStart()
  {
//--- the shared rates buffer uses series (newest-first) indexing
   ArraySetAsSeries(rates, true);

//--- train and evaluate starting at bar position 1 (position 0 is skipped)
   const int first_bar = 1;
   TrainTestLR(first_bar);
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//| Return `total` close prices starting at bar position start_bar,  |
//| for the chart symbol on the same timeframe (tf) as the indicator |
//| features.                                                        |
//| The vector API (vector::CopyRates) is used so that element order |
//| follows the same convention as vector::CopyIndicatorBuffer,      |
//| which produces the feature columns. The previous version used    |
//| PERIOD_CURRENT and a series-indexed MqlRates array, so the target|
//| could come from a different timeframe and/or in a different      |
//| order than the features. It also ignored a short copy, which     |
//| indexed past the end of the array.                               |
//| On failure an error is printed and the (possibly empty/short)    |
//| vector is returned; the caller should check its size.            |
//+------------------------------------------------------------------+
vector GetCloseRates(int start_bar, int total)
 {
   vector close_prices;

   if (!close_prices.CopyRates(Symbol(), tf, COPY_RATES_CLOSE, start_bar, total))
      printf("Failed to copy close prices Err = %d",GetLastError());

   return close_prices;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
//+------------------------------------------------------------------+
//| Build the indicator feature matrix and the close-price target    |
//| over buffer_size bars starting at start_bar, optionally reduce   |
//| the features (NMF / truncated SVD), then fit a linear regression |
//| on a 70/30 split and print the train/test R-squared.             |
//+------------------------------------------------------------------+
void TrainTestLR(int start_bar=1)
 {
   matrix data = indicators.GetAllBuffers(names, start_bar, buffer_size);

//--- Getting close values

   vector target = GetCloseRates(start_bar, buffer_size);

//--- Features and target must describe the same bars; a failed/short copy breaks this
   if (data.Rows() != target.Size())
     {
      printf("Feature rows (%d) and target size (%d) differ, aborting",(int)data.Rows(),(int)target.Size());
      return;
     }

//--- Dimension reduction process

   switch(dimension_redux)
     {
      case  NMF:
        {
          // NOTE(review): NMF normally requires non-negative input, while some
          // indicator buffers here (e.g. MACD) can be negative — confirm CNMF handles this.
          nmf = new CNMF(nmf_iterations);

          uint k = nmf.select_best_components(data);
          //Print("Best k components = ",k);
          data = nmf.fit_transform(data, k);
        }
        break;
      case TRUNC_SVD:

         truncated_svd = new CTruncatedSVD();
         data = truncated_svd.fit_transform(data);

        break;
      case None:
        break;
     }

//---

  Print(EnumToString(dimension_redux)," Reduced matrix[",data.Rows(),"x",data.Cols(),"]\n",data);

//---

   matrix train_x, test_x;
   vector train_y, test_y;

   data = matrix_utils.concatenate(data, target); //add the target variable to the dimension reduced data

   //Print("Data\n",data);

   matrix_utils.TrainTestSplitMatrices(data, train_x, train_y, test_x, test_y, 0.7, 42);

   //Print("Corr coeff\n",data.CorrCoef(false));

   lr.fit(train_x, train_y, NORM_MIN_MAX_SCALER); //training Linear regression model

   vector preds = lr.predict(train_x); //Predicting the training data

   Print("Train acc = ",metrics.r_squared(train_y, preds)); //Measuring the performance

   preds = lr.predict(test_x); //predicting the test data

   Print("Test acc = ",metrics.r_squared(test_y, preds)); //measuring the performance

//--- Free the reducers and reset the globals so later calls never see stale pointers

   if (CheckPointer(truncated_svd)!=POINTER_INVALID)
     {
      delete (truncated_svd);
      truncated_svd = NULL;
     }

   if (CheckPointer(nmf)!=POINTER_INVALID)
     {
      delete (nmf);
      nmf = NULL;
     }
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+